# Copyright 2020 Huawei Technologies Co., Ltd

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import math
import random
import numpy as np
import cv2
from pycocotools.coco import COCO as ReadJson

import mindspore.dataset as de

from src.config import JointType, params

# Set OpenCV's internal thread count to 0 at import time. Per the OpenCV
# documentation this makes OpenCV run its functions sequentially instead of
# using its own thread pool. NOTE(review): presumably chosen because the
# dataset is already driven by multiple workers -- confirm with the author.
cv2.setNumThreads(0)

class txtdataset():
    def __init__(self, train, imgpath, maskpath, insize, mode='train', n_samples=None):
        """Index the 'person' images of an annotation object.

        Args:
            train: annotation object; this constructor calls ``getCatIds`` and
                ``getImgIds`` on it (other methods also use ``getAnnIds``,
                ``loadAnns`` and ``loadImgs``).
            imgpath: directory joined with each image's ``file_name``.
            maskpath: directory joined with the per-image mask file name.
            insize: side length of the square crop produced for training.
            mode: 'train' filters the ids through ``clean_imgIds``; 'val' and
                'eval' can be randomly subsampled via ``n_samples``.
            n_samples: number of ids to keep in 'val'/'eval' mode, or None to
                keep all of them.
        """
        self.train = train
        self.imgpath, self.maskpath = imgpath, maskpath
        self.insize = insize
        self.mode = mode
        self.maxtime = 0

        self.catIds = train.getCatIds(catNms=['person'])
        self.imgIds = sorted(train.getImgIds(catIds=self.catIds))

        # The two mode-specific steps are mutually exclusive.
        if mode == 'train':
            self.clean_imgIds()
        elif mode in ('val', 'eval') and n_samples is not None:
            self.imgIds = random.sample(self.imgIds, n_samples)

        print('{} images: {}'.format(mode, len(self)))


    def __len__(self):
        """Return how many image ids are currently indexed."""
        n_images = len(self.imgIds)
        return n_images

    def clean_imgIds(self):
        """Drop image ids that have no usable person annotation.

        An annotation is usable when its ``num_keypoints`` is at least
        ``params['min_keypoints']`` and its ``area`` is greater than
        ``params['min_area']``. Ids without any annotation are dropped too.

        The surviving ids are written back into the existing ``self.imgIds``
        list in one pass. The previous version called ``list.remove`` for every
        rejected id, which rescans the list each time (quadratic overall).
        """
        print("cleaning imgids")

        min_keypoints = params['min_keypoints']
        min_area = params['min_area']

        def has_valid_person(img_id):
            """Return True if at least one annotation of ``img_id`` passes both thresholds."""
            anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)
            if not anno_ids:
                return False
            return any(
                annotation['num_keypoints'] >= min_keypoints and annotation['area'] > min_area
                for annotation in self.train.loadAnns(anno_ids)
            )

        # Slice assignment keeps the same list object, as the in-place removal did.
        self.imgIds[:] = [img_id for img_id in self.imgIds if has_valid_person(img_id)]

    def overlay_paf(self, img, paf):
        """Blend a colour rendering of a 2-channel vector field over ``img``.

        ``paf[0]`` and ``paf[1]`` are used as the x and y components. The
        direction drives the hue, the magnitude (clipped to 1) drives both
        saturation and value, and the result is mixed 60% image / 40% field.
        """
        field_x, field_y = paf[0], paf[1]

        hue = ((np.arctan2(field_y, field_x) / np.pi) / -2 + 0.5)
        saturation = np.sqrt(field_x ** 2 + field_y ** 2)
        saturation[saturation > 1.0] = 1.0
        value = saturation.copy()

        # Channels-last (H, W, 3) HSV image scaled to 0..255.
        hsv_field = np.stack((hue, saturation, value), axis=-1)
        hsv_u8 = (hsv_field * 255).astype(np.uint8)
        coloured = cv2.cvtColor(hsv_u8, cv2.COLOR_HSV2BGR)
        return cv2.addWeighted(img, 0.6, coloured, 0.4, 0)

    def overlay_pafs(self, img, pafs):
        mix_paf = np.zeros((2,) + img.shape[:-1])
        paf_flags = np.zeros(mix_paf.shape) # for constant paf

        for paf in pafs.reshape((int(pafs.shape[0]/2), 2,) + pafs.shape[1:]):
            paf_flags = paf != 0
            paf_flags += np.broadcast_to(paf_flags[0] | paf_flags[1], paf.shape)
            mix_paf += paf

        mix_paf[paf_flags > 0] /= paf_flags[paf_flags > 0]
        img = self.overlay_paf(img, mix_paf)
        return img

    def overlay_heatmap(self, img, heatmap):
        """Blend a JET colour-mapped ``heatmap`` over ``img`` (60% image, 40% map)."""
        heat_u8 = (heatmap * 255).astype(np.uint8)
        coloured = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)
        return cv2.addWeighted(img, 0.6, coloured, 0.4, 0)

    def overlay_ignore_mask(self, img, ignore_mask):
        img = img * np.repeat((ignore_mask == 0).astype(np.uint8)[:, :, None], 3, axis=2)
        return img

    # -------------------- augment code --------------------------------
    def get_pose_bboxes(self, poses):
        pose_bboxes = []
        for pose in poses:
            x1 = pose[pose[:, 2] > 0][:, 0].min()
            y1 = pose[pose[:, 2] > 0][:, 1].min()
            x2 = pose[pose[:, 2] > 0][:, 0].max()
            y2 = pose[pose[:, 2] > 0][:, 1].max()
            pose_bboxes.append([x1, y1, x2, y2])
        pose_bboxes = np.array(pose_bboxes)
        return pose_bboxes

    def resize_data(self, img, ignore_mask, poses, shape):
        """Resize image and mask to ``shape`` and rescale the joint coordinates.

        ``shape`` is handed to ``cv2.resize`` as its target size and is also
        used as the (x, y) numerator of the coordinate scale. The first two
        columns of ``poses`` are overwritten in place, so the caller's array
        is modified and the rescaled values are cast to its dtype.
        """
        src_h, src_w, _ = img.shape

        out_img = cv2.resize(img, shape)
        mask_u8 = ignore_mask.astype(np.uint8)
        out_mask = cv2.resize(mask_u8, shape).astype('bool')

        # Multiply first, then divide: the order affects rounding before the cast.
        source_wh = np.array((src_w, src_h))
        poses[:, :, :2] = (poses[:, :, :2] * np.array(shape) / source_wh)
        return out_img, out_mask, poses

    def random_resize_img(self, img, ignore_mask, poses):
        """Rescale image, mask and poses by a random factor.

        The factor is drawn uniformly between a lower and an upper bound. The
        bounds come from the smallest and largest pose-box diagonal relative to
        ``params['min_box_size']`` / ``params['max_box_size']``, then are
        clamped so that lower is in [``params['min_scale']``, 1] and upper is
        in [1, ``params['max_scale']``].
        """
        h, w, _ = img.shape
        boxes = self.get_pose_bboxes(poses)
        extents = boxes[:, 2:] - boxes[:, :2] + 1
        diagonals = (extents ** 2).sum(axis=1) ** 0.5

        lower = params['min_box_size'] / diagonals.min()
        upper = params['max_box_size'] / diagonals.max()
        lower = min(max(lower, params['min_scale']), 1)
        upper = min(max(upper, 1), params['max_scale'])

        scale = float((upper - lower) * random.random() + lower)
        target = (round(w * scale), round(h * scale))
        return self.resize_data(img, ignore_mask, poses, target)

    def random_rotate_img(self, img, mask, poses):
        """Rotate image, mask and poses by a random angle about the image centre.

        The angle in degrees is a standard-normal sample scaled by
        ``params['max_rotate_degree'] / 3``. The output canvas is the bounding
        box of the rotated image; uncovered image pixels are filled with 127.5.
        The third pose column is copied through unchanged.
        """
        h, w, _ = img.shape
        degree = np.random.randn() / 3 * params['max_rotate_degree']
        rad = degree * math.pi / 180
        cos_abs, sin_abs = abs(math.cos(rad)), abs(math.sin(rad))

        center = (w / 2, h / 2)
        bbox = (w * cos_abs + h * sin_abs, w * sin_abs + h * cos_abs)
        out_size = (int(bbox[0] + 0.5), int(bbox[1] + 0.5))

        R = cv2.getRotationMatrix2D(center, degree, 1)
        # Shift so the rotated content is centred in the enlarged canvas.
        R[0, 2] += bbox[0] / 2 - center[0]
        R[1, 2] += bbox[1] / 2 - center[1]

        rotate_img = cv2.warpAffine(img, R, out_size, flags=cv2.INTER_CUBIC,
                                    borderMode=cv2.BORDER_CONSTANT, borderValue=[127.5, 127.5, 127.5])
        rotate_mask = cv2.warpAffine(mask.astype('uint8') * 255, R, out_size) > 0

        # Homogeneous coordinates (x, y, 1) so the 2x3 affine matrix applies directly.
        homogeneous = np.ones_like(poses)
        homogeneous[:, :, :2] = poses[:, :, :2].copy()
        rotated_xy = np.dot(homogeneous, R.T)

        rotate_poses = poses.copy()  # keeps the visibility column
        rotate_poses[:, :, :2] = rotated_xy
        return rotate_img, rotate_mask, rotate_poses

    def random_crop_img(self, img, ignore_mask, poses):
        """Cut an ``insize`` x ``insize`` window centred near a randomly chosen pose.

        The window centre is the centre of one random pose box, shifted by up
        to ``params['center_perterb_max']`` in each axis. Window areas outside
        the image stay at 127.5 (image) / False (mask). The first two columns
        of ``poses`` are shifted in place into window coordinates.
        """
        h, w, _ = img.shape
        side = self.insize
        half_span = (side - 1) / 2

        target_box = random.choice(self.get_pose_bboxes(poses))  # select a bbox randomly
        box_center = target_box[:2] + (target_box[2:] - target_box[:2]) / 2
        jitter = ((np.random.rand(2) - 0.5) * 2 * params['center_perterb_max'])
        center = (box_center + jitter + 0.5).astype('i')

        canvas = np.full((side, side, 3), 127.5)
        crop_mask = np.zeros((side, side), 'bool')

        # Window's top-left corner in image coordinates (may be negative).
        offset = (center - half_span + 0.5).astype('i')
        # How far the window's bottom-right corner runs past the image.
        overshoot = (center + half_span - (w - 1, h - 1) + 0.5).astype('i')

        # Source rectangle, clamped to the image.
        x1, y1 = offset
        x2, y2 = (center + half_span + 0.5).astype('i')
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, w - 1), min(y2, h - 1)

        # Destination rectangle inside the window.
        x_from = max(-offset[0], 0)
        y_from = max(-offset[1], 0)
        x_to = side - 1 - max(overshoot[0], 0)
        y_to = side - 1 - max(overshoot[1], 0)

        canvas[y_from:y_to + 1, x_from:x_to + 1] = img[y1:y2 + 1, x1:x2 + 1]
        crop_mask[y_from:y_to + 1, x_from:x_to + 1] = ignore_mask[y1:y2 + 1, x1:x2 + 1]

        poses[:, :, :2] -= offset
        return canvas.astype('uint8'), crop_mask, poses

    def distort_color(self, img):
        """Randomly shift hue, saturation and value of a BGR image.

        Each HSV channel gets one random integer offset for the whole image:
        hue in [-10, 10], saturation in [-40, 40], value in [-30, 30]. The
        shifted channel is clamped to [0, 255] before converting back to BGR.
        """
        upper = np.broadcast_to(np.array(255, dtype=np.uint8), img.shape[:-1])
        lower = np.zeros(img.shape[:-1], dtype=np.uint8)

        hsv = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV).astype(np.int32)
        # (channel index, maximum absolute shift); random draws stay in H, S, V order.
        for channel, max_shift in ((0, 10), (1, 40), (2, 30)):
            shifted = hsv[:, :, channel] - max_shift + np.random.randint(2 * max_shift + 1)
            hsv[:, :, channel] = np.maximum(np.minimum(shifted, upper), lower)

        return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

    def flip_img(self, img, mask, poses):
        """Mirror image, mask and poses horizontally.

        x coordinates become ``width - 1 - x`` and every left/right joint pair
        is exchanged so the labels still match the mirrored body. ``poses`` is
        modified in place and also returned.
        """
        mirrored_img = cv2.flip(img, 1)
        mirrored_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool')
        poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0]

        left_right_pairs = (
            (JointType.LeftEye, JointType.RightEye),
            (JointType.LeftEar, JointType.RightEar),
            (JointType.LeftShoulder, JointType.RightShoulder),
            (JointType.LeftElbow, JointType.RightElbow),
            (JointType.LeftHand, JointType.RightHand),
            (JointType.LeftWaist, JointType.RightWaist),
            (JointType.LeftKnee, JointType.RightKnee),
            (JointType.LeftFoot, JointType.RightFoot),
        )
        for left, right in left_right_pairs:
            left_backup = poses[:, left].copy()
            poses[:, left] = poses[:, right]
            poses[:, right] = left_backup

        return mirrored_img, mirrored_mask, poses

    def augment_data(self, img, ignore_mask, poses):
        """Run the augmentation chain on one sample.

        Random resize, rotate and crop are always applied, in that order, to a
        copy of ``img``. Colour distortion and horizontal flip are each applied
        with probability 1/2.
        """
        sample = (img.copy(), ignore_mask, poses)
        for geometric_step in (self.random_resize_img, self.random_rotate_img, self.random_crop_img):
            sample = geometric_step(*sample)
        aug_img, ignore_mask, poses = sample

        if np.random.randint(2):
            aug_img = self.distort_color(aug_img)
        if np.random.randint(2):
            aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses)

        return aug_img, ignore_mask, poses
    # ------------------------------- end -----------------------------------


    # ------------------------------ Heatmap ------------------------------------
    # return shape: (height, width)
    def generate_gaussian_heatmap(self, shape, joint, sigma):
        x, y = joint
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
        gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma**2)
        return gaussian_heatmap

    def generate_heatmaps(self, img, poses, heatmap_sigma):
        """Build one full-resolution heatmap per joint type plus a background channel.

        For every joint type the per-pose Gaussians are merged with a pixelwise
        maximum. The last channel is ``1 - max over all joints``. The result is
        float32 with ``len(JointType) + 1`` channels.
        """
        map_shape = img.shape[:-1]
        channels = []
        peak_over_joints = np.zeros(map_shape)

        for joint_index in range(len(JointType)):
            channel = np.zeros(map_shape)
            for pose in poses:
                if not pose[joint_index, 2] > 0:
                    continue  # joint not labelled for this pose
                jointmap = self.generate_gaussian_heatmap(map_shape, pose[joint_index][:2], heatmap_sigma)
                stronger = jointmap > channel
                channel[stronger] = jointmap[stronger]
                stronger_overall = jointmap > peak_over_joints
                peak_over_joints[stronger_overall] = jointmap[stronger_overall]
            channels.append(channel)

        channels.append(1 - peak_over_joints)  # background channel
        return np.stack(channels).astype('f')

    def generate_gaussian_heatmap_fast(self, shape, joint, sigma):
        x, y = joint
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_x = grid_x + 0.4375
        grid_y = grid_y + 0.4375
        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
        gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma**2)
        return gaussian_heatmap

    def generate_heatmaps_fast(self, img, poses, heatmap_sigma):
        """Build joint heatmaps directly at 1/8 of the image resolution.

        Joint coordinates and ``heatmap_sigma`` are divided by 8 and rendered
        on a ``(H // 8, W // 8)`` grid. Per joint type the pose Gaussians are
        merged by pixelwise maximum; the final channel is the background
        ``1 - max over all joints``. Returns float32.
        """
        small_shape = (img.shape[0] // 8, img.shape[1] // 8)
        channels = []
        peak_over_joints = np.zeros(small_shape)

        for joint_index in range(len(JointType)):
            channel = np.zeros(small_shape)
            for pose in poses:
                if not pose[joint_index, 2] > 0:
                    continue  # joint not labelled for this pose
                jointmap = self.generate_gaussian_heatmap_fast(small_shape, pose[joint_index][:2] / 8,
                                                               heatmap_sigma / 8)
                stronger = jointmap > channel
                channel[stronger] = jointmap[stronger]
                stronger_overall = jointmap > peak_over_joints
                peak_over_joints[stronger_overall] = jointmap[stronger_overall]
            channels.append(channel)

        channels.append(1 - peak_over_joints)  # background channel
        return np.stack(channels).astype('f')
    # ------------------------------ end ------------------------------------

    # ------------------------------ PAF ------------------------------------
    # return shape: (2, height, width)
    def generate_constant_paf(self, shape, joint_from, joint_to, paf_width):
        if np.array_equal(joint_from, joint_to): # same joint
            return np.zeros((2,) + shape[:-1])

        joint_distance = np.linalg.norm(joint_to - joint_from)
        unit_vector = (joint_to - joint_from) / joint_distance
        rad = np.pi / 2
        # [[0, 1], [-1, 0]]
        rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
        # [[u_y], [-u_x]]
        vertical_unit_vector = np.dot(rot_matrix, unit_vector)
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
        horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\
                                 (grid_y - joint_from[1])
        vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width  # paf_width : 8
        paf_flag = horizontal_paf_flag & vertical_paf_flag
        constant_paf = np.stack((paf_flag, paf_flag)) *\
                       np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)

        return constant_paf

    def generate_pafs(self, img, poses, paf_sigma):
        """Build full-resolution vector fields for every limb in ``params['limbs_point']``.

        For each limb the per-pose constant fields are summed and then divided,
        pixel by pixel, by the number of poses whose field is non-zero there.
        The limb fields (two channels each) are concatenated along axis 0 and
        returned as float32.
        """
        plane_shape = img.shape[:-1]
        limb_fields = [np.zeros((0,) + plane_shape)]

        for limb in params['limbs_point']:
            field = np.zeros((2,) + plane_shape)
            hits = np.zeros(field.shape)  # number of poses covering each pixel

            for pose in poses:
                joint_from, joint_to = pose[limb]
                if not (joint_from[2] > 0 and joint_to[2] > 0):
                    continue  # both end joints must be labelled
                limb_paf = self.generate_constant_paf(img.shape, joint_from[:2], joint_to[:2], paf_sigma)
                nonzero = limb_paf != 0
                hits += np.broadcast_to(nonzero[0] | nonzero[1], limb_paf.shape)
                field += limb_paf

            covered = hits > 0
            field[covered] /= hits[covered]
            limb_fields.append(field)

        return np.vstack(limb_fields).astype('f')

    def generate_constant_paf_fast(self, shape, joint_from, joint_to, paf_width):
        if np.array_equal(joint_from, joint_to): # same joint
            return np.zeros((2,) + shape[:-1])

        joint_distance = np.linalg.norm(joint_to - joint_from)
        unit_vector = (joint_to - joint_from) / joint_distance
        rad = np.pi / 2
        # [[0, 1], [-1, 0]]
        rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
        # [[u_y], [-u_x]]
        vertical_unit_vector = np.dot(rot_matrix, unit_vector)
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_x = grid_x + 0.4375
        grid_y = grid_y + 0.4375
        horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
        horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\
                                 (grid_y - joint_from[1])
        vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width  # paf_width : 8/8 = 1
        paf_flag = horizontal_paf_flag & vertical_paf_flag
        constant_paf = np.stack((paf_flag, paf_flag)) *\
                       np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)

        return constant_paf

    def generate_pafs_fast(self, img, poses, paf_sigma):
        """Build limb vector fields directly at 1/8 of the image resolution.

        Joint coordinates and ``paf_sigma`` are divided by 8 and rendered on a
        ``(H // 8, W // 8)`` grid. Per limb the pose fields are summed and
        divided by the number of poses that are non-zero at each pixel. Limb
        fields are concatenated along axis 0 and returned as float32.
        """
        # Image-style shape; the trailing 3 is dropped by the field generator.
        small_shape = (img.shape[0] // 8, img.shape[1] // 8, 3)
        plane_shape = small_shape[:-1]
        limb_fields = [np.zeros((0,) + plane_shape)]

        for limb in params['limbs_point']:
            field = np.zeros((2,) + plane_shape)
            hits = np.zeros(field.shape)  # number of poses covering each pixel

            for pose in poses:
                joint_from, joint_to = pose[limb]
                if not (joint_from[2] > 0 and joint_to[2] > 0):
                    continue  # both end joints must be labelled
                limb_paf = self.generate_constant_paf_fast(small_shape, joint_from[:2] / 8,
                                                           joint_to[:2] / 8, paf_sigma / 8)
                nonzero = limb_paf != 0
                hits += np.broadcast_to(nonzero[0] | nonzero[1], limb_paf.shape)
                field += limb_paf

            covered = hits > 0
            field[covered] /= hits[covered]
            limb_fields.append(field)

        return np.vstack(limb_fields).astype('f')
    # ------------------------------ end ------------------------------------

    def get_img_annotation(self, ind=None, img_id=None):
        """Load one image with its annotations and ignore mask.

        Args:
            ind: position in ``self.imgIds``; when given it overrides ``img_id``.
            img_id: image id to load when ``ind`` is None.

        Returns:
            ``(img, img_id, annotations, ignore_mask)``.

            * In 'eval' mode ``annotations`` is the unfiltered annotation list
              (empty when the image has none) and the mask is returned as-is.
            * Otherwise ``annotations`` holds only entries with at least
              ``params['min_keypoints']`` keypoints and an area greater than
              ``params['min_area']``, or None when nothing qualifies; the mask
              is cast to float32.

        Fix: 'eval' mode used to raise ``UnboundLocalError`` for images without
        any annotation id, because the raw list was only bound inside the
        ``if anno_ids`` branch.
        """
        if ind is not None:
            img_id = self.imgIds[ind]
        anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)

        # Always bind the raw list so the 'eval' return below is safe.
        annotations_for_img = self.train.loadAnns(anno_ids) if anno_ids else []

        # Keep only persons with enough keypoints and a large enough area.
        valid_annotations_for_img = [
            annotation for annotation in annotations_for_img
            if annotation['num_keypoints'] >= params['min_keypoints']
            and annotation['area'] > params['min_area']
        ]
        annotations = valid_annotations_for_img if valid_annotations_for_img else None

        img_path = os.path.join(self.imgpath, self.train.loadImgs([img_id])[0]['file_name'])
        mask_path = os.path.join(self.maskpath, '{:012d}.png'.format(img_id))
        img = cv2.imread(img_path)
        ignore_mask = cv2.imread(mask_path, 0)
        if ignore_mask is None:
            # No mask could be read: nothing is ignored.
            ignore_mask = np.zeros(img.shape[:2], np.float32)
        else:
            # Mask pixels stored as 255 are rewritten to 1.
            ignore_mask[ignore_mask == 255] = 1

        if self.mode == 'eval':
            return img, img_id, annotations_for_img, ignore_mask

        return img, img_id, annotations, ignore_mask.astype('f')

    def parse_annotation(self, annotations):
        """Convert annotation dicts into an int32 ``(N, len(JointType), 3)`` pose array.

        Each annotation's ``keypoints`` list is read as rows of three values
        and row ``i`` is stored at joint slot ``params['joint_indices'][i]``.
        When the third value of both shoulders is greater than 0, the neck slot
        is set to the truncated midpoint of the shoulders with third value 2.
        """
        n_joints = len(JointType)
        pose_list = [np.zeros((0, n_joints, 3), dtype=np.int32)]

        for ann in annotations:
            keypoints = np.array(ann['keypoints']).reshape(-1, 3)
            pose = np.zeros((1, n_joints, 3), dtype=np.int32)
            joints = pose[0]  # view: writes below land in ``pose``

            # convert poses position
            for source_index, joint_index in enumerate(params['joint_indices']):
                joints[joint_index] = keypoints[source_index]

            # compute neck position
            left_shoulder = joints[JointType.LeftShoulder]
            right_shoulder = joints[JointType.RightShoulder]
            if left_shoulder[2] > 0 and right_shoulder[2] > 0:
                neck = joints[JointType.Neck]
                neck[0] = int((left_shoulder[0] + right_shoulder[0]) / 2)
                neck[1] = int((left_shoulder[1] + right_shoulder[1]) / 2)
                neck[2] = 2

            pose_list.append(pose)

        return np.vstack(pose_list)

    def resize_output(self, input_np, map_h=46, map_w=46):
        """Resize a label map to ``map_h`` x ``map_w`` and return it as float32.

        A 3-D input is resized slice by slice along axis 0 and returned with
        shape ``(input_np.shape[0], map_h, map_w)``. Any other input is cast to
        float32 and resized in a single call.

        Fix: the single-call branch passed ``(map_h, map_w)`` to ``cv2.resize``,
        whose size argument is ordered (width, height). Both branches now use
        ``(map_w, map_h)``; results are unchanged when ``map_h == map_w``.
        """
        dsize = (map_w, map_h)  # cv2.resize expects (width, height)

        if input_np.ndim == 3:
            output = np.zeros((input_np.shape[0], map_h, map_w))
            for i, channel in enumerate(input_np):
                output[i] = cv2.resize(channel, dsize)
            return output.astype('f')

        return cv2.resize(input_np.astype('f'), dsize)

    def generate_labels(self, img, poses, ignore_mask):
        """Augment one sample and produce its training targets.

        Returns ``(resized_img, pafs, heatmaps, ignore_mask)``. The image is
        augmented and resized to ``insize`` x ``insize``. PAFs and heatmaps
        come from the ``*_fast`` generators at 1/8 resolution. The ignore mask
        is dilated with a 16x16 kernel and then passed through
        ``resize_output``.
        """
        aug_img, aug_mask, aug_poses = self.augment_data(img, ignore_mask, poses)

        net_size = (self.insize, self.insize)
        resized_img, net_mask, resized_poses = self.resize_data(aug_img, aug_mask, aug_poses,
                                                                shape=net_size)

        resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma'])
        resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma'])

        dilation_kernel = np.ones((16, 16))
        dilated_mask = cv2.morphologyEx(net_mask.astype('uint8'), cv2.MORPH_DILATE, dilation_kernel).astype('bool')
        resized_ignore_mask = self.resize_output(dilated_mask)

        return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask

    def preprocess(self, img):
        x_data = img.astype('f')
        x_data /= 255
        x_data -= 0.5
        x_data = x_data.transpose(2, 0, 1)
        return x_data

    def __getitem__(self, i):
        """Return the sample at position ``i``.

        In 'eval'/'val' mode the result is ``(img, np.array([img_id]))``.
        Otherwise it is ``(image, pafs, heatmaps, mask)`` where the image has
        gone through ``preprocess`` and the mask is ``1 - ignore_mask``. If the
        requested image has no usable annotation, random other images are
        drawn until one does.
        """
        img, img_id, annotations, ignore_mask = self.get_img_annotation(ind=i)

        if self.mode in ('eval', 'val'):
            # don't need to make heatmaps/pafs
            return img, np.array([img_id])

        # if no annotations are available, fall back to a random image
        while annotations is None:
            print("none annotations", img_id)
            fallback_id = self.imgIds[np.random.randint(len(self))]
            img, img_id, annotations, ignore_mask = self.get_img_annotation(img_id=fallback_id)

        poses = self.parse_annotation(annotations)
        net_img, pafs, heatmaps, label_mask = self.generate_labels(img, poses, ignore_mask)

        # Invert so that 1 marks pixels that are NOT ignored.
        return self.preprocess(net_img), pafs, heatmaps, 1. - label_mask


class DistributedSampler():
    """Yield the dataset indices assigned to one rank out of ``group_size``.

    The index list is padded by repetition to a multiple of ``group_size`` and
    every rank takes each ``group_size``-th entry starting at ``rank``, so all
    ranks yield ``len(self)`` indices for a non-empty dataset.

    Args:
        dataset: any object supporting ``len()``.
        rank: index of this consumer, used as the slice start.
        group_size: total number of consumers, used as the slice step.
        shuffle: when True each iteration draws a new seeded permutation.
        seed: base seed; it is incremented (masked to 32 bits) on every
            shuffled iteration.
    """

    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samplers * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        """Return an iterator over this rank's indices for one pass."""
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            # NOTE: this reseeds NumPy's global generator, as the original code did.
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            # Fix: the original called len() on the integer length here,
            # which raised TypeError whenever shuffle was False.
            indices = list(range(self.dataset_len))

        if indices:
            # Pad by cycling through the indices. This equals the original
            # ``indices += indices[:pad]`` whenever pad <= len(indices), and
            # still reaches total_size when the dataset is much smaller than
            # group_size.
            repeats = -(-self.total_size // len(indices))
            indices = (indices * repeats)[:self.total_size]

        return iter(indices[self.rank::self.group_size])

    def __len__(self):
        """Return the number of indices each rank receives per pass."""
        return self.num_samplers

def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
    """Build the validation/evaluation dataset pipeline.

    Loads the annotation file at ``jsonpath``, wraps it in ``txtdataset`` with
    the given ``mode`` and exposes columns 'img' and 'img_id' through a
    ``GeneratorDataset`` with 8 parallel workers. Samples are split across
    ``group_size`` ranks by a ``DistributedSampler`` with its default
    ``shuffle=True``.
    """
    annotations = ReadJson(jsonpath)
    val_set = txtdataset(annotations, imgpath, maskpath, params['insize'], mode=mode)
    rank_sampler = DistributedSampler(val_set, rank, group_size)
    pipeline = de.GeneratorDataset(val_set, ['img', 'img_id'], num_parallel_workers=8, sampler=rank_sampler)
    return pipeline.repeat(1)


def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True,
                   multiprocessing=True, num_worker=20):
    """Build the batched training dataset pipeline.

    Loads the annotation file at ``jsonpath``, wraps it in ``txtdataset`` and
    exposes columns "image", "pafs", "heatmaps" and "ignore_mask". When
    ``group_size`` is not 1 the data is sharded with ``num_shards=group_size``
    and ``shard_id=rank``. Batches drop the incomplete remainder and the
    batched dataset is repeated ``repeat_num`` times.
    """
    annotations = ReadJson(jsonpath)
    dataset = txtdataset(annotations, imgpath, maskpath, params['insize'], mode=mode)

    generator_options = {
        'shuffle': shuffle,
        'num_parallel_workers': num_worker,
        'python_multiprocessing': multiprocessing,
    }
    if group_size != 1:
        # Multi-device run: each rank reads only its own shard.
        generator_options['num_shards'] = group_size
        generator_options['shard_id'] = rank

    column_names = ["image", "pafs", "heatmaps", "ignore_mask"]
    de_dataset = de.GeneratorDataset(dataset, column_names, **generator_options)
    de_dataset = de_dataset.batch(batch_size=batch_size, drop_remainder=True)
    return de_dataset.repeat(repeat_num)
